diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
@@ -16,6 +16,9 @@
 namespace LLVM {
 
+/// Create a pass to remove BF16 types from LLVM IR.
+std::unique_ptr<Pass> createSoftwareBF16Pass();
+
 /// Generate the code for registering conversion passes.
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
@@ -16,4 +16,18 @@
   let constructor = "mlir::LLVM::createLegalizeForExportPass()";
 }
 
+def SoftwareBF16 : Pass<"llvm-software-bf16"> {
+  let summary = "Convert BF16 to I16 in LLVM IR";
+  let description = [{
+    This pass erases the BF16 type from LLVM IR.
+
+    Some LLVM targets do not support LLVM's `bfloat` type, or only support it
+    incompletely. To allow using the `bf16` type on such targets, this pass
+    replaces all of its uses by `i16` and then replaces operations on `bf16` by
+    extending the 16-bit values into `f32`, then computes the floating-point
+    operation on the extended value, and then truncates the results.
+  }];
+  let constructor = "mlir::LLVM::createSoftwareBF16Pass()";
+}
+
 #endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -39,6 +40,7 @@
 #include "../GPUCommon/IndexIntrinsicsOpLowering.h"
 #include "../GPUCommon/OpToFuncCallLowering.h"
 #include "../PassDetail.h"
+#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
 
 using namespace mlir;
 
@@ -91,6 +93,11 @@
     configureGpuToROCDLConversionLegality(target);
     if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
       signalPassFailure();
+
+    OpPassManager pm("gpu.module");
+    pm.addPass(LLVM::createSoftwareBF16Pass());
+    if (failed(runPipeline(pm, getOperation())))
+      signalPassFailure();
   }
 };
 
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRLLVMIRTransforms
   LegalizeForExport.cpp
+  SoftwareBf16.cpp
 
   DEPENDS
   MLIRLLVMPassIncGen
@@ -8,4 +9,5 @@
   MLIRIR
   MLIRLLVMIR
   MLIRPass
+  MLIRLLVMCommonConversion
 )
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/SoftwareBf16.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/SoftwareBf16.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/SoftwareBf16.cpp
@@ -0,0 +1,384 @@
+//===- SoftwareBf16.cpp - Prepare for translation to LLVM IR ---------===//
+//
+// Part of the LLVM Project, under the
+// Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+/// Reinterpret the bits of a bf16 constant as an i16 value.
+static APInt castBF16toInt(APFloat value) {
+  assert(&value.getSemantics() == &APFloat::BFloat() && "Must cast bf16 only");
+  APInt ret = value.bitcastToAPInt();
+  assert(ret.getBitWidth() == 16 && "bf16 conversion should make i16");
+  return ret;
+}
+
+/// Materialize an i32 constant of `type` (splatted when `type` is a vector).
+static Value getLlvmI32Const(Location loc, PatternRewriter &rewriter, Type type,
+                             int32_t value) {
+  Attribute ret = rewriter.getI32IntegerAttr(value);
+  if (LLVM::isCompatibleVectorType(type))
+    ret = SplatElementsAttr::get(type.cast<ShapedType>(), ret);
+  return rewriter.createOrFold<LLVM::ConstantOp>(loc, type, ret);
+}
+
+namespace {
+/// Rewrites bf16 constants to their i16 equivalents
+/// This is relying on the fact that the vector, i16, and bf16 types used in the
+/// LLVM dialect are the standard ones and not weird custom wrappers
+struct BF16ConstCasting : OpRewritePattern<LLVM::ConstantOp> {
+  explicit BF16ConstCasting(MLIRContext *context) : OpRewritePattern(context) {}
+
+  LogicalResult matchAndRewrite(LLVM::ConstantOp op,
+                                PatternRewriter &rewriter) const override {
+    Attribute val = op.getValueAttr();
+    Operation *rawOp = op.getOperation();
+    Type bf16 = rewriter.getBF16Type();
+    Type retType = op.getRes().getType();
+    Type retElemType = retType;
+
+    if (auto retTypeShaped = retType.dyn_cast<ShapedType>())
+      retElemType = retTypeShaped.getElementType();
+
+    // Scalar bf16 constant: rebuild it as an integer constant with the
+    // same bit pattern.
+    if (auto valFloat = val.dyn_cast<FloatAttr>()) {
+      if (valFloat.getType() != bf16)
+        return failure();
+      APInt newVal = castBF16toInt(valFloat.getValue());
+      rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
+          rawOp, retType, rewriter.getIntegerAttr(retType, newVal));
+      return success();
+    }
+
+    // Dense elements: a bitcast of the payload preserves all values.
+    if (auto valDense = val.dyn_cast<DenseElementsAttr>()) {
+      if (valDense.getElementType() != bf16)
+        return failure();
+      DenseElementsAttr newVal = valDense.bitcast(retElemType);
+      rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(rawOp, retType, newVal);
+      return success();
+    }
+
+    // Sparse elements: bitcast the stored values, keep the index table.
+    if (auto valSparse = val.dyn_cast<SparseElementsAttr>()) {
+      if (valSparse.getElementType() != bf16)
+        return failure();
+      DenseElementsAttr values = valSparse.getValues();
+      DenseElementsAttr newValues = values.bitcast(retElemType);
+      auto newVal = SparseElementsAttr::get(retType.cast<ShapedType>(),
+                                            valSparse.getIndices(), newValues);
+      rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(rawOp, retType, newVal);
+      return success();
+    }
+    // No match otherwise
+    return failure();
+  }
+};
+
+/// Rewrites a floating-point operation on (i16-encoded) bf16 values into the
+/// same operation on f32: extend the operands, compute in f32, truncate back.
+/// The fpext/fptrunc ops created here are cleaned up by SoftwareBF16Ext and
+/// SoftwareBF16Trunc below.
+template <typename Op>
+struct BF16AsF32 : OpRewritePattern<Op> {
+  explicit BF16AsF32(MLIRContext *context) : OpRewritePattern<Op>(context) {}
+
+  LogicalResult matchAndRewrite(Op op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    Type opdType = op->getOperand(0).getType();
+    Type opdElementType = opdType;
+    Type i16 = rewriter.getIntegerType(16);
+
+    if (auto opdShaped = opdType.dyn_cast<ShapedType>()) {
+      opdElementType = opdShaped.getElementType();
+    }
+
+    Type resType = op.getResult().getType();
+    Type extType = rewriter.getF32Type();
+    Type resElementType = resType;
+
+    if (auto resShaped = resType.dyn_cast<ShapedType>()) {
+      extType = resShaped.clone(extType);
+      resElementType = resShaped.getElementType();
+    }
+
+    if (resElementType != i16 && opdElementType != i16)
+      return failure();
+
+    llvm::SmallVector<Value> extended;
+    if (isa<LLVM::SIToFPOp>(op) || isa<LLVM::UIToFPOp>(op)) {
+      // The operand is an ordinary integer, not an encoded bf16 — use it
+      // unchanged and let only the result be truncated.
+      extended.push_back(op->getOperand(0));
+    } else {
+      for (Value v : op->getOperands()) {
+        extended.push_back(
+            rewriter.create<LLVM::FPExtOp>(loc, extType, v)); // i16->f32
+      }
+    }
+
+    if (resElementType == i16) {
+      Op operation =
+          rewriter.create<Op>(loc, extType, extended, op->getAttrs());
+      rewriter.replaceOpWithNewOp<LLVM::FPTruncOp>(op, resType,
+                                                   operation.getResult());
+    } else { // FCmp and fp-to-int ops keep their original result type.
+      rewriter.replaceOpWithNewOp<Op>(op, resType, extended, op->getAttrs());
+    }
+
+    return success();
+  }
+};
+
+/// Rewrites extension from (i16-encoded) bf16 to f32 as integer operations:
+/// zero-extend to i32, shift into the top half, bitcast to f32.
+struct SoftwareBF16Ext : OpRewritePattern<LLVM::FPExtOp> {
+  explicit SoftwareBF16Ext(MLIRContext *context) : OpRewritePattern(context) {}
+
+  LogicalResult matchAndRewrite(LLVM::FPExtOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    Type srcType = op.getArg().getType();
+    Type destType = op.getResult().getType();
+    Type srcElemType = srcType;
+    if (auto shaped = srcType.dyn_cast<ShapedType>())
+      srcElemType = shaped.getElementType();
+
+    Type i16 = rewriter.getIntegerType(16);
+    if (srcElemType != i16)
+      return failure();
+
+    Type extType = rewriter.getI32Type();
+    if (auto srcShaped = srcType.dyn_cast<ShapedType>())
+      extType = srcShaped.clone(extType);
+
+    Type f32 = rewriter.getF32Type();
+    if (auto destShaped = destType.dyn_cast<ShapedType>()) {
+      if (destShaped.getElementType() != f32)
+        return failure();
+    } else if (destType != f32)
+      return failure();
+
+    Value extended = rewriter.create<LLVM::ZExtOp>(loc, extType, op.getArg());
+    Value shifted = rewriter.create<LLVM::ShlOp>(
+        loc, extended, getLlvmI32Const(loc, rewriter, extType, 16));
+    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, destType, shifted);
+
+    return success();
+  }
+};
+
+/// Rewrites truncation to bfloat as a series of integer operations.
+struct SoftwareBF16Trunc : OpRewritePattern<LLVM::FPTruncOp> {
+  explicit SoftwareBF16Trunc(MLIRContext *context)
+      : OpRewritePattern(context) {}
+
+  LogicalResult matchAndRewrite(LLVM::FPTruncOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+
+    Type srcType = op.getArg().getType();
+    Type destType = op.getRes().getType();
+    Type srcElemType = srcType;
+    if (auto shaped = srcType.dyn_cast<ShapedType>())
+      srcElemType = shaped.getElementType();
+
+    Type f32 = rewriter.getF32Type();
+    if (srcElemType != f32)
+      return failure();
+
+    Type bitcastType = rewriter.getI32Type();
+    if (auto srcShaped = srcType.dyn_cast<ShapedType>())
+      bitcastType = srcShaped.clone(bitcastType);
+
+    Type i16 = rewriter.getIntegerType(16);
+    if (auto destShaped = destType.dyn_cast<ShapedType>()) {
+      if (destShaped.getElementType() != i16)
+        return failure();
+    } else if (destType != i16)
+      return failure();
+
+    // Round-to-nearest-even truncation of f32 to bf16, on the raw bits:
+    //   a = bitcast f32 value to i32
+    //   lsb = (a >> 16) & 1
+    //   d = a + 32767 + lsb
+    //   truncate (d >> 16) to i16 and return this i16
+    // Note: the rounding bias (32767 + lsb) must be added to `a` *before*
+    // the single shift by 16; shifting the biased value twice would discard
+    // the entire result.
+    Value bitcastop =
+        rewriter.create<LLVM::BitcastOp>(loc, bitcastType, op.getArg());
+
+    Value constantSixteen = getLlvmI32Const(loc, rewriter, bitcastType, 16);
+    Value shiftValue = rewriter.create<LLVM::LShrOp>(
+        loc, bitcastType, bitcastop, constantSixteen);
+
+    Value constantOne = getLlvmI32Const(loc, rewriter, bitcastType, 1);
+    Value andValue = rewriter.create<LLVM::AndOp>(loc, shiftValue, constantOne);
+
+    Value constantBig = getLlvmI32Const(loc, rewriter, bitcastType, 32767);
+    Value addBigValue =
+        rewriter.create<LLVM::AddOp>(loc, bitcastop, constantBig);
+
+    Value addValue = rewriter.create<LLVM::AddOp>(loc, andValue, addBigValue);
+
+    Value shiftBeforeTruncValue = rewriter.create<LLVM::LShrOp>(
+        loc, bitcastType, addValue, constantSixteen);
+    Value truncValue =
+        rewriter.create<LLVM::TruncOp>(loc, destType, shiftBeforeTruncValue);
+    rewriter.replaceOp(op.getOperation(), {truncValue});
+
+    return success();
+  }
+};
+
+} // namespace
+
+/// Rewrite the types on `op` in place, mapping bf16 to i16 (including inside
+/// function signatures, globals, vectors, pointers, structs, and arrays).
+static void replaceBF16WithI16(Operation *op, TypeConverter &converter) {
+  if (auto func = dyn_cast<LLVM::LLVMFuncOp>(op)) {
+    auto funcType = func.getFunctionType();
+    func.setType(converter.convertType(funcType));
+    for (Value arg : func.getArguments())
+      arg.setType(converter.convertType(arg.getType()));
+  } else if (auto globalOp = dyn_cast<LLVM::GlobalOp>(op)) {
+    Type globalType = globalOp.getType();
+    globalOp.setGlobalTypeAttr(
+        TypeAttr::get(converter.convertType(globalType)));
+  } else {
+    for (unsigned idx = 0; idx < op->getNumOperands(); idx++) {
+      auto type = converter.convertType(op->getOperand(idx).getType());
+      op->getOperand(idx).setType(type);
+    }
+    for (unsigned idx = 0; idx < op->getNumResults(); idx++) {
+      auto type = converter.convertType(op->getResult(idx).getType());
+      op->getResult(idx).setType(type);
+    }
+  }
+  return;
+}
+
+static void populateSoftwareBF16Patterns(MLIRContext *ctx,
+                                         TypeConverter &converter,
+                                         RewritePatternSet &patterns) {
+  Type llvmI16 = IntegerType::get(ctx, 16);
+
+  converter.addConversion([](Type type) { return type; });
+
+  converter.addConversion(
+      [llvmI16](BFloat16Type type) -> Type { return llvmI16; });
+
+  converter.addConversion([&](VectorType type) -> Optional<Type> {
+    if (auto element = converter.convertType(type.getElementType()))
+      return type.clone(element);
+    return llvm::None;
+  });
+
+  converter.addConversion(
+      [&](LLVM::LLVMPointerType type) -> llvm::Optional<Type> {
+        if (auto pointee = converter.convertType(type.getElementType()))
+          return LLVM::LLVMPointerType::get(pointee, type.getAddressSpace());
+        return llvm::None;
+      });
+
+  converter.addConversion(
+      [&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results,
+          ArrayRef<Type> callStack) -> Optional<LogicalResult> {
+        bool converted = false;
+        SmallVector<Type> convertedElemTypes;
+        convertedElemTypes.reserve(type.getBody().size());
+        for (auto t : type.getBody()) {
+          SmallVector<Type, 1> element;
+          if (failed(converter.convertType(t, element)))
+            return llvm::None;
+          assert(element.size() == 1);
+          convertedElemTypes.push_back(element[0]);
+          if (t != element[0])
+            converted = true;
+        }
+
+        if (!converted) {
+          results.push_back(type);
+          return success();
+        }
+
+        // Identified StructType
+        if (type.isIdentified()) {
+          auto convertedType = LLVM::LLVMStructType::getIdentified(
+              type.getContext(), ("_Converted_" + type.getName()).str());
+          unsigned counter = 1;
+          while (convertedType.isInitialized()) {
+            convertedType = LLVM::LLVMStructType::getIdentified(
+                type.getContext(),
+                ("_Converted_" + Twine(counter++) + type.getName()).str());
+          }
+          // Break self-recursive structs: reuse the placeholder while the
+          // same type is being converted further up the call stack.
+          if (llvm::count(callStack, type) > 1) {
+            results.push_back(convertedType);
+            return success();
+          }
+          if (failed(
+                  convertedType.setBody(convertedElemTypes, type.isPacked())))
+            return llvm::None;
+          results.push_back(convertedType);
+          return success();
+        }
+
+        // Literal StructType
+        results.push_back(LLVM::LLVMStructType::getLiteral(
+            type.getContext(), convertedElemTypes, type.isPacked()));
+        return success();
+      });
+
+  converter.addConversion(
+      [&](LLVM::LLVMArrayType type) -> llvm::Optional<Type> {
+        if (auto element = converter.convertType(type.getElementType()))
+          return LLVM::LLVMArrayType::get(element, type.getNumElements());
+        return llvm::None;
+      });
+
+  converter.addConversion(
+      [&](LLVM::LLVMFunctionType type) -> llvm::Optional<Type> {
+        Type convertedResType = converter.convertType(type.getReturnType());
+        if (!convertedResType)
+          return llvm::None;
+
+        SmallVector<Type> convertedArgTypes;
+        convertedArgTypes.reserve(type.getNumParams());
+        if (failed(converter.convertTypes(type.getParams(), convertedArgTypes)))
+          return llvm::None;
+        return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
+                                           type.isVarArg());
+      });
+
+  patterns.add<BF16ConstCasting, SoftwareBF16Trunc, SoftwareBF16Ext>(ctx);
+
+  // NOTE(review): template arguments were lost in transit; this op list is
+  // reconstructed from the special cases BF16AsF32 handles — confirm against
+  // the original patch.
+  patterns.add<BF16AsF32<LLVM::FAddOp>, BF16AsF32<LLVM::FSubOp>,
+               BF16AsF32<LLVM::FMulOp>, BF16AsF32<LLVM::FDivOp>,
+               BF16AsF32<LLVM::FRemOp>, BF16AsF32<LLVM::FNegOp>,
+               BF16AsF32<LLVM::FAbsOp>, BF16AsF32<LLVM::SqrtOp>,
+               BF16AsF32<LLVM::FCeilOp>, BF16AsF32<LLVM::FFloorOp>,
+               BF16AsF32<LLVM::FMAOp>, BF16AsF32<LLVM::FCmpOp>,
+               BF16AsF32<LLVM::SIToFPOp>, BF16AsF32<LLVM::UIToFPOp>,
+               BF16AsF32<LLVM::FPToSIOp>, BF16AsF32<LLVM::FPToUIOp>>(ctx);
+}
+
+namespace {
+struct SoftwareBF16Pass : public SoftwareBF16Base<SoftwareBF16Pass> {
+  void runOnOperation() override {
+    auto m = getOperation();
+    MLIRContext *ctx = m->getContext();
+    TypeConverter converter;
+    RewritePatternSet bf16fixupPatterns(ctx);
+
+    populateSoftwareBF16Patterns(ctx, converter, bf16fixupPatterns);
+    // Replace BF16 types in an operation with I16 types
+    m->walk([&converter](Operation *op) { replaceBF16WithI16(op, converter); });
+
+    if (failed(applyPatternsAndFoldGreedily(m, std::move(bf16fixupPatterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> LLVM::createSoftwareBF16Pass() {
+  return std::make_unique<SoftwareBF16Pass>();
+}
diff --git a/mlir/test/Conversion/GPUCommon/memory-attrbution.mlir b/mlir/test/Conversion/GPUCommon/memory-attrbution.mlir
--- a/mlir/test/Conversion/GPUCommon/memory-attrbution.mlir
+++ b/mlir/test/Conversion/GPUCommon/memory-attrbution.mlir
@@ -22,25 +22,13 @@
   // NVVM: %[[c1:.*]] = llvm.mlir.constant(1 : index) : i64
   // NVVM: %[[descr6:.*]] = llvm.insertvalue %[[c1]], %[[descr5]][4, 0]
 
-  // ROCDL: %[[descr1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32, 5>, ptr<f32, 5>, i64, array<1 x i64>, array<1 x i64>)>
-  // ROCDL: %[[descr2:.*]] = llvm.insertvalue %[[raw]], %[[descr1]][0]
-  // ROCDL: %[[descr3:.*]] = llvm.insertvalue %[[raw]], %[[descr2]][1]
-  // ROCDL: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64
-  // ROCDL: %[[descr4:.*]] = llvm.insertvalue %[[c0]], %[[descr3]][2]
-  // ROCDL: %[[c4:.*]] = llvm.mlir.constant(4 : index) : i64
-  // ROCDL: %[[descr5:.*]] = llvm.insertvalue %[[c4]], %[[descr4]][3, 0]
-  // ROCDL: %[[c1:.*]] = llvm.mlir.constant(1 : index) : i64
-  // ROCDL: %[[descr6:.*]] = llvm.insertvalue %[[c1]], %[[descr5]][4, 0]
-
   // "Store" lowering should work just as any other memref, only check that
   // we emit some core instructions.
// NVVM: llvm.extractvalue %[[descr6:.*]] // NVVM: llvm.getelementptr // NVVM: llvm.store - // ROCDL: llvm.extractvalue %[[descr6:.*]] - // ROCDL: llvm.getelementptr - // ROCDL: llvm.store + // ROCDL: llvm.store {{.*}}, %[[raw]] %c0 = arith.constant 0 : index memref.store %arg0, %arg1[%c0] : memref<4xf32, 5> @@ -88,25 +76,13 @@ // NVVM: %[[c1:.*]] = llvm.mlir.constant(1 : index) : i64 // NVVM: %[[descr6:.*]] = llvm.insertvalue %[[c1]], %[[descr5]][4, 0] - // ROCDL: %[[descr1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // ROCDL: %[[descr2:.*]] = llvm.insertvalue %[[raw]], %[[descr1]][0] - // ROCDL: %[[descr3:.*]] = llvm.insertvalue %[[raw]], %[[descr2]][1] - // ROCDL: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64 - // ROCDL: %[[descr4:.*]] = llvm.insertvalue %[[c0]], %[[descr3]][2] - // ROCDL: %[[c4:.*]] = llvm.mlir.constant(4 : index) : i64 - // ROCDL: %[[descr5:.*]] = llvm.insertvalue %[[c4]], %[[descr4]][3, 0] - // ROCDL: %[[c1:.*]] = llvm.mlir.constant(1 : index) : i64 - // ROCDL: %[[descr6:.*]] = llvm.insertvalue %[[c1]], %[[descr5]][4, 0] - // "Store" lowering should work just as any other memref, only check that // we emit some core instructions. 
// NVVM: llvm.extractvalue %[[descr6:.*]] // NVVM: llvm.getelementptr // NVVM: llvm.store - // ROCDL: llvm.extractvalue %[[descr6:.*]] - // ROCDL: llvm.getelementptr - // ROCDL: llvm.store + // ROCDL: llvm.store {{.*}}, %[[raw]] %c0 = arith.constant 0 : index memref.store %arg0, %arg1[%c0] : memref<4xf32, 3> @@ -159,23 +135,8 @@ // NVVM: %[[c1:.*]] = llvm.mlir.constant(1 : index) : i64 // NVVM: %[[descr10:.*]] = llvm.insertvalue %[[c1]], %[[descr9]][4, 2] - // ROCDL: %[[descr1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> - // ROCDL: %[[descr2:.*]] = llvm.insertvalue %[[raw]], %[[descr1]][0] - // ROCDL: %[[descr3:.*]] = llvm.insertvalue %[[raw]], %[[descr2]][1] - // ROCDL: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64 - // ROCDL: %[[descr4:.*]] = llvm.insertvalue %[[c0]], %[[descr3]][2] - // ROCDL: %[[c4:.*]] = llvm.mlir.constant(4 : index) : i64 - // ROCDL: %[[descr5:.*]] = llvm.insertvalue %[[c4]], %[[descr4]][3, 0] - // ROCDL: %[[c12:.*]] = llvm.mlir.constant(12 : index) : i64 - // ROCDL: %[[descr6:.*]] = llvm.insertvalue %[[c12]], %[[descr5]][4, 0] - // ROCDL: %[[c2:.*]] = llvm.mlir.constant(2 : index) : i64 - // ROCDL: %[[descr7:.*]] = llvm.insertvalue %[[c2]], %[[descr6]][3, 1] - // ROCDL: %[[c6:.*]] = llvm.mlir.constant(6 : index) : i64 - // ROCDL: %[[descr8:.*]] = llvm.insertvalue %[[c6]], %[[descr7]][4, 1] - // ROCDL: %[[c6:.*]] = llvm.mlir.constant(6 : index) : i64 - // ROCDL: %[[descr9:.*]] = llvm.insertvalue %[[c6]], %[[descr8]][3, 2] - // ROCDL: %[[c1:.*]] = llvm.mlir.constant(1 : index) : i64 - // ROCDL: %[[descr10:.*]] = llvm.insertvalue %[[c1]], %[[descr9]][4, 2] + // ROCDL: %[[offset:.*]] = llvm.getelementptr %[[raw]] + // ROCDL: llvm.store {{.*}} %[[offset]] %c0 = arith.constant 0 : index memref.store %arg0, %arg1[%c0,%c0,%c0] : memref<4x2x6xf32, 3> @@ -203,6 +164,9 @@ workgroup(%arg1: memref<1xf32, 3>, %arg2: memref<2xf32, 3>) private(%arg3: memref<3xf32, 5>, %arg4: memref<4xf32, 5>) { + // ROCDL: 
%[[c4:.*]] = llvm.mlir.constant(4 : i64) + // ROCDL: %[[c3:.*]] = llvm.mlir.constant(3 : i64) + // Workgroup buffers. // NVVM: llvm.mlir.addressof @[[$buffer1]] // NVVM: llvm.mlir.addressof @[[$buffer2]] @@ -216,9 +180,7 @@ // NVVM: %[[c4:.*]] = llvm.mlir.constant(4 : i64) // NVVM: llvm.alloca %[[c4]] x f32 : (i64) -> !llvm.ptr - // ROCDL: %[[c3:.*]] = llvm.mlir.constant(3 : i64) // ROCDL: llvm.alloca %[[c3]] x f32 : (i64) -> !llvm.ptr - // ROCDL: %[[c4:.*]] = llvm.mlir.constant(4 : i64) // ROCDL: llvm.alloca %[[c4]] x f32 : (i64) -> !llvm.ptr %c0 = arith.constant 0 : index diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir --- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir +++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-hip.mlir @@ -9,14 +9,12 @@ // CHECK-LABEL: func @test_const_printf gpu.func @test_const_printf() { + // CHECK: %[[ISLAST:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[FORMATLEN:.*]] = llvm.mlir.constant(14 : i64) : i64 // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64 - // CHECK-NEXT: %[[DESC0:.*]] = llvm.call @__ockl_printf_begin(%0) : (i64) -> i64 - // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr> - // CHECK-NEXT: %[[CST1:.*]] = llvm.mlir.constant(0 : i64) : i64 - // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][%[[CST1]], %[[CST1]]] : (!llvm.ptr>, i64, i64) -> !llvm.ptr - // CHECK-NEXT: %[[FORMATLEN:.*]] = llvm.mlir.constant(14 : i64) : i64 - // CHECK-NEXT: %[[ISLAST:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK-NEXT: %[[ISNTLAST:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[DESC0:.*]] = llvm.call @__ockl_printf_begin({{.*}}) : (i64) -> i64 + // CHECK: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr> + // CHECK: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][%[[CST0]], %[[CST0]]] : (!llvm.ptr>, i64, i64) -> !llvm.ptr // CHECK-NEXT: 
%{{.*}} = llvm.call @__ockl_printf_append_string_n(%[[DESC0]], %[[FORMATSTART]], %[[FORMATLEN]], %[[ISLAST]]) : (i64, !llvm.ptr, i64, i32) -> i64 gpu.printf "Hello, world\n" gpu.return @@ -26,18 +24,16 @@ // CHECK-LABEL: func @test_printf // CHECK: (%[[ARG0:.*]]: i32) gpu.func @test_printf(%arg0: i32) { + // CHECK: %[[ISNTLAST:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[NARGS1:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[FORMATLEN:.*]] = llvm.mlir.constant(11 : i64) : i64 // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64 - // CHECK-NEXT: %[[DESC0:.*]] = llvm.call @__ockl_printf_begin(%0) : (i64) -> i64 - // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr> - // CHECK-NEXT: %[[CST1:.*]] = llvm.mlir.constant(0 : i64) : i64 - // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][%[[CST1]], %[[CST1]]] : (!llvm.ptr>, i64, i64) -> !llvm.ptr - // CHECK-NEXT: %[[FORMATLEN:.*]] = llvm.mlir.constant(11 : i64) : i64 - // CHECK-NEXT: %[[ISLAST:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK-NEXT: %[[ISNTLAST:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[DESC0:.*]] = llvm.call @__ockl_printf_begin({{.*}}) : (i64) -> i64 + // CHECK: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr> + // CHECK: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][%[[CST0]], %[[CST0]]] : (!llvm.ptr>, i64, i64) -> !llvm.ptr // CHECK-NEXT: %[[DESC1:.*]] = llvm.call @__ockl_printf_append_string_n(%[[DESC0]], %[[FORMATSTART]], %[[FORMATLEN]], %[[ISNTLAST]]) : (i64, !llvm.ptr, i64, i32) -> i64 - // CHECK-NEXT: %[[NARGS1:.*]] = llvm.mlir.constant(1 : i32) : i32 // CHECK-NEXT: %[[ARG0_64:.*]] = llvm.zext %[[ARG0]] : i32 to i64 - // CHECK-NEXT: %{{.*}} = llvm.call @__ockl_printf_append_args(%[[DESC1]], %[[NARGS1]], %[[ARG0_64]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[ISLAST]]) : (i64, i32, i64, i64, i64, i64, i64, i64, i64, i32) -> i64 + // CHECK-NEXT: 
%{{.*}} = llvm.call @__ockl_printf_append_args(%[[DESC1]], %[[NARGS1]], %[[ARG0_64]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]], %[[NARGS1]]) : (i64, i32, i64, i64, i64, i64, i64, i64, i64, i32) -> i64 gpu.printf "Hello: %d\n" %arg0 : i32 gpu.return } diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir --- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir +++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-opencl.mlir @@ -6,8 +6,8 @@ // CHECK-LABEL: func @test_printf // CHECK: (%[[ARG0:.*]]: i32) gpu.func @test_printf(%arg0: i32) { + // CHECK: %[[IMM1:.*]] = llvm.mlir.constant(0 : i64) : i64 // CHECK: %[[IMM0:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL]] : !llvm.ptr, 4> - // CHECK-NEXT: %[[IMM1:.*]] = llvm.mlir.constant(0 : i64) : i64 // CHECK-NEXT: %[[IMM2:.*]] = llvm.getelementptr %[[IMM0]][%[[IMM1]], %[[IMM1]]] : (!llvm.ptr, 4>, i64, i64) -> !llvm.ptr // CHECK-NEXT: %{{.*}} = llvm.call @printf(%[[IMM2]], %[[ARG0]]) : (!llvm.ptr, i32) -> i32 gpu.printf "Hello: %d\n" %arg0 : i32 diff --git a/mlir/test/Conversion/SoftwareBF16/softwareBF16.mlir b/mlir/test/Conversion/SoftwareBF16/softwareBF16.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/SoftwareBF16/softwareBF16.mlir @@ -0,0 +1,59 @@ +// RUN: mlir-opt --llvm-software-bf16 %s| FileCheck %s + +module attributes {llvm.data_layout = ""} { + llvm.func @verify_bf16_f32(%arg0: bf16, %arg1: f32) -> i32 { +//CHECK: llvm.func @verify_bf16_f32(%arg0: i16, %arg1: f32) -> i32 { + + %0 = llvm.mlir.constant(0 : i32) : i32 + %1 = llvm.mlir.constant(1 : i32) : i32 +//CHECK-DAG: %[[V0:.*]] = llvm.mlir.constant(32767 : i32) : i32 +//CHECK-DAG: %[[V1:.*]] = llvm.mlir.constant(16 : i32) : i32 +//CHECK-DAG: %[[V2:.*]] = llvm.mlir.constant(1 : i32) : i32 +//CHECK-DAG: %[[V3:.*]] = llvm.mlir.constant(0 : i32) : i32 + + %2 = llvm.mlir.constant(1.503910e-01 : bf16) : bf16 +//CHECK-DAG: %[[V4:.*]] = 
llvm.mlir.constant(15898 : i16) : i16 + + %3 = llvm.fptrunc %arg1 : f32 to bf16 +//CHECK: %[[V5:.*]] = llvm.bitcast %arg1 : f32 to i32 +//CHECK-NEXT: %[[V6:.*]] = llvm.lshr %[[V5]], %[[V1]] : i32 +//CHECK-NEXT: %[[V7:.*]] = llvm.and %[[V6]], %[[V2]] : i32 +//CHECK-NEXT: %[[V8:.*]] = llvm.add %[[V5]], %[[V0]] : i32 +//CHECK-NEXT: %[[V9:.*]] = llvm.lshr %[[V8]], %[[V1]] : i32 +//CHECK-NEXT: %[[V10:.*]] = llvm.add %[[V7]], %[[V9]] : i32 +//CHECK-NEXT: %[[V11:.*]] = llvm.lshr %[[V10]], %[[V1]] : i32 +//CHECK-NEXT: %[[V12:.*]] = llvm.trunc %[[V11]] : i32 to i16 + + %4 = llvm.fsub %arg0, %3 : bf16 +//CHECK: %[[V13:.*]] = llvm.zext %arg0 : i16 to i32 +//CHECK-NEXT: %[[V14:.*]] = llvm.shl %[[V13]], %[[V1]] : i32 +//CHECK-NEXT: %[[V15:.*]] = llvm.bitcast %[[V14]] : i32 to f32 +//CHECK: %[[V16:.*]] = llvm.zext %[[V12]] : i16 to i32 +//CHECK-NEXT: %[[V17:.*]] = llvm.shl %[[V16]], %1 : i32 +//CHECK-NEXT: %[[V18:.*]] = llvm.bitcast %[[V17]] : i32 to f32 +//CHECK: %[[V19:.*]] = llvm.fsub %[[V15]], %[[V18]] : f32 +//CHECK-NEXT: %[[V20:.*]] = llvm.bitcast %[[V19]] : f32 to i32 +//CHECK-NEXT: %[[V21:.*]] = llvm.lshr %[[V20]], %[[V1]] : i32 +//CHECK-NEXT: %[[V22:.*]] = llvm.and %[[V21]], %[[V2]] : i32 +//CHECK-NEXT: %[[V23:.*]] = llvm.add %[[V20]], %[[V0]] : i32 +//CHECK-NEXT: %[[V24:.*]] = llvm.lshr %[[V23]], %[[V1]] : i32 +//CHECK-NEXT: %[[V25:.*]] = llvm.add %[[V22]], %[[V24]] : i32 +//CHECK-NEXT: %[[V26:.*]] = llvm.lshr %[[V25]], %[[V1]] : i32 +//CHECK-NEXT: %[[V27:.*]] = llvm.trunc %[[V26]] : i32 to i16 + + %5 = llvm.fcmp "ugt" %4, %2 : bf16 +//CHECK: %[[V28:.*]] = llvm.zext %[[V27]] : i16 to i32 +//CHECK-NEXT: %[[V29:.*]] = llvm.shl %[[V28]], %[[V1]] : i32 +//CHECK-NEXT: %[[V30:.*]] = llvm.bitcast %[[V29]] : i32 to f32 +//CHECK: %[[V31:.*]] = llvm.zext %[[V4]] : i16 to i32 +//CHECK-NEXT: %[[V32:.*]] = llvm.shl %[[V31]], %[[V1]] : i32 +//CHECK-NEXT: %[[V33:.*]] = llvm.bitcast %[[V32]] : i32 to f32 +//CHECK: %{{.*}} = llvm.fcmp "ugt" %[[V30]], %[[V33]] : f32 + + llvm.cond_br %5, 
^bb1, ^bb2 + ^bb1: // pred: ^bb0 + llvm.return %1 : i32 + ^bb2: // pred: ^bb0 + llvm.return %0 : i32 + } +}