diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -38,6 +38,9 @@
 /// Add patterns to expand Arith ceil/floor division ops.
 void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
 
+/// Add patterns to expand bf16 `arith.extf`/`arith.truncf` into integer ops.
+void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith ops.
 void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -32,6 +32,10 @@
   let summary = "Legalize Arith ops to be convertible to LLVM.";
   let constructor = "mlir::arith::createArithExpandOpsPass()";
   let dependentDialects = ["vector::VectorDialect"];
+  let options = [
+    Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
+           "Enable the BF16 expansion patterns">,
+  ];
 }
 
 def ArithUnsignedWhenEquivalent : Pass<"arith-unsigned-when-equivalent"> {
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -25,15 +26,13 @@
 /// Create an integer or index constant.
 static Value createConst(Location loc, Type type, int value,
                          PatternRewriter &rewriter) {
-
-  auto elTy = getElementTypeOrSelf(type);
-  auto constantAttr = rewriter.getIntegerAttr(elTy, value);
-
-  if (auto vecTy = llvm::dyn_cast<VectorType>(type))
+  auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
+  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
     return rewriter.create<arith::ConstantOp>(
-        loc, vecTy, DenseElementsAttr::get(vecTy, constantAttr));
+        loc, DenseElementsAttr::get(shapedTy, attr));
+  }
 
-  return rewriter.create<arith::ConstantOp>(loc, constantAttr);
+  return rewriter.create<arith::ConstantOp>(loc, attr);
 }
 
 namespace {
@@ -187,6 +186,107 @@
   }
 };
 
+/// Expands `arith.extf` from bf16 to f32 (scalar or shaped) into integer ops.
+/// A bf16 value has the same bits as the upper half of the equivalent f32, so
+/// the conversion is an exact zero-extension followed by a left shift by 16.
+struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!operandETy.isBF16() || !resultETy.isF32()) {
+      return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
+    }
+
+    Type i16Ty = b.getI16Type();
+    Type i32Ty = b.getI32Type();
+    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+      i16Ty = shapedTy.clone(i16Ty);
+      i32Ty = shapedTy.clone(i32Ty);
+    }
+
+    Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
+    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+
+    Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
+    Value shl = b.create<arith::ShLIOp>(exti, c16);
+    Value result = b.create<arith::BitcastOp>(resultTy, shl);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/// Expands `arith.truncf` from f32 to bf16 (scalar or shaped) into integer ops
+/// implementing round-to-nearest-even, the default rounding of `arith.truncf`.
+///
+/// The f32 bit pattern is offset by a rounding bias of 0x7fff, plus one when
+/// bit 16 (the lowest bit kept in the bf16 result) is set so that ties round
+/// to even, and then its upper 16 bits are kept. A carry out of the mantissa
+/// increments the exponent, which is exactly the required renormalization, so
+/// the largest finite values correctly round up to infinity. Infinities have a
+/// zero mantissa and pass through unchanged. NaNs are replaced by a canonical
+/// quiet NaN, because rounding could otherwise turn a NaN whose payload lies
+/// only in the low 16 bits into an infinity.
+struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!operandETy.isF32() || !resultETy.isBF16()) {
+      return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
+    }
+
+    Type i16Ty = b.getI16Type();
+    Type i32Ty = b.getI32Type();
+    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+      i16Ty = shapedTy.clone(i16Ty);
+      i32Ty = shapedTy.clone(i32Ty);
+    }
+
+    // NaN is the only value that compares unequal to itself.
+    Value isNan =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
+
+    Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
+    Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
+    Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
+    // Canonical positive quiet NaN in bf16.
+    Value cQuietNan = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
+
+    Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
+
+    // Round to nearest, ties to even: the bias is 0x7fff when bit 16 is clear
+    // and 0x8000 when it is set.
+    Value shr = b.create<arith::ShRUIOp>(bitcast, c16);
+    Value bit16 = b.create<arith::AndIOp>(shr, c1);
+    Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
+
+    // A carry out of the mantissa deliberately ripples into the exponent.
+    Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
+    Value upper = b.create<arith::ShRUIOp>(biased, c16);
+    Value trunc = b.create<arith::TruncIOp>(i16Ty, upper);
+
+    Value select = b.create<arith::SelectOp>(isNan, cQuietNan, trunc);
+    Value result = b.create<arith::BitcastOp>(resultTy, select);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct ArithExpandOpsPass
     : public arith::impl::ArithExpandOpsBase<ArithExpandOpsPass> {
   void runOnOperation() override {
@@ -204,6 +304,24 @@
       arith::MaxFOp,
       arith::MinFOp
     >();
+
+    if (includeBf16) {
+      arith::populateExpandBFloat16Patterns(patterns);
+      target.addDynamicallyLegalOp<arith::ExtFOp>(
+        [](arith::ExtFOp op) {
+          Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+          Type outETy = getElementTypeOrSelf(op.getType());
+          return !(inETy.isBF16() && outETy.isF32());
+        });
+
+      target.addDynamicallyLegalOp<arith::TruncFOp>(
+        [](arith::TruncFOp op) {
+          Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+          Type outETy = getElementTypeOrSelf(op.getType());
+          return !(inETy.isF32() && outETy.isBF16());
+        });
+    }
+
     // clang-format on
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -220,6 +338,11 @@
       patterns.getContext());
 }
 
+void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
+  patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
+      patterns.getContext());
+}
+
 void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
   populateCeilFloorDivExpandOpsPatterns(patterns);
   // clang-format off
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -arith-expand -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -arith-expand="include-bf16=true" -split-input-file | FileCheck %s
 
 // Test ceil divide with signed integer
 // CHECK-LABEL: func @ceildivi
@@ -215,3 +215,56 @@
 // CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32
 // CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
 // CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+func.func @extf_bf16(%arg0 : bf16) -> f32 {
+  %0 = arith.extf %arg0 : bf16 to f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @extf_bf16
+// CHECK-SAME:    (%[[ARG:.+]]: bf16)
+// CHECK:         %[[BITCAST:.+]] = arith.bitcast %[[ARG]] : bf16 to i16
+// CHECK:         %[[EXT:.+]] = arith.extui %[[BITCAST]] : i16 to i32
+// CHECK:         %[[C16:.+]] = arith.constant 16 : i32
+// CHECK:         %[[SHL:.+]] = arith.shli %[[EXT]], %[[C16]] : i32
+// CHECK:         %[[RES:.+]] = arith.bitcast %[[SHL]] : i32 to f32
+// CHECK:         return %[[RES]]
+
+// -----
+
+func.func @truncf_f32(%arg0 : f32) -> bf16 {
+  %0 = arith.truncf %arg0 : f32 to bf16
+  return %0 : bf16
+}
+
+// CHECK-LABEL: @truncf_f32
+// CHECK-SAME:    (%[[ARG:.+]]: f32)
+// CHECK-DAG:     %[[ISNAN:.+]] = arith.cmpf une, %[[ARG]], %[[ARG]] : f32
+// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : i32
+// CHECK-DAG:     %[[C16:.+]] = arith.constant 16 : i32
+// CHECK-DAG:     %[[C7FFF:.+]] = arith.constant 32767 : i32
+// CHECK-DAG:     %[[CNAN:.+]] = arith.constant 32704 : i16
+// CHECK-DAG:     %[[BITCAST:.+]] = arith.bitcast %[[ARG]] : f32 to i32
+// CHECK:         %[[SHR:.+]] = arith.shrui %[[BITCAST]], %[[C16]] : i32
+// CHECK:         %[[BIT16:.+]] = arith.andi %[[SHR]], %[[C1]] : i32
+// CHECK:         %[[BIAS:.+]] = arith.addi %[[BIT16]], %[[C7FFF]] : i32
+// CHECK:         %[[BIASED:.+]] = arith.addi %[[BITCAST]], %[[BIAS]] : i32
+// CHECK:         %[[UPPER:.+]] = arith.shrui %[[BIASED]], %[[C16]] : i32
+// CHECK:         %[[TRUNC:.+]] = arith.trunci %[[UPPER]] : i32 to i16
+// CHECK:         %[[SEL:.+]] = arith.select %[[ISNAN]], %[[CNAN]], %[[TRUNC]] : i16
+// CHECK:         %[[RES:.+]] = arith.bitcast %[[SEL]] : i16 to bf16
+// CHECK:         return %[[RES]]
+
+// -----
+
+func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
+  %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xbf16>
+  return %0 : vector<4xbf16>
+}
+
+// CHECK-LABEL: @truncf_vector_f32
+// CHECK-NOT:     arith.truncf
+// CHECK:         arith.select %{{.+}} : vector<4xi1>, vector<4xi16>
+// CHECK-NOT:     arith.truncf
diff --git a/mlir/test/mlir-cpu-runner/expand-arith-ops.mlir b/mlir/test/mlir-cpu-runner/expand-arith-ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/expand-arith-ops.mlir
@@ -0,0 +1,84 @@
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(arith-expand{include-bf16=true},convert-arith-to-llvm),convert-vector-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)" \
+// RUN: | mlir-cpu-runner                                                      \
+// RUN:     -e main -entry-point-result=void -O0                               \
+// RUN:     -shared-libs=%mlir_c_runner_utils  \
+// RUN:     -shared-libs=%mlir_runner_utils    \
+// RUN: | FileCheck %s
+
+func.func @trunc_bf16(%a : f32) {
+  %b = arith.truncf %a : f32 to bf16
+  %c = arith.extf %b : bf16 to f32
+  vector.print %c : f32
+  return
+}
+
+func.func @main() {
+  // Ties round to even: 0x3f808000 is halfway between 0x3f80 and 0x3f81.
+  // CHECK: {{^}}1{{$}}
+  %tieDownI = arith.constant 0x3f808000 : i32
+  %tieDownF = arith.bitcast %tieDownI : i32 to f32
+  call @trunc_bf16(%tieDownF): (f32) -> ()
+
+  // Ties round to even: 0x3f818000 is halfway between 0x3f81 and 0x3f82.
+  // CHECK-NEXT: 1.0156
+  %tieUpI = arith.constant 0x3f818000 : i32
+  %tieUpF = arith.bitcast %tieUpI : i32 to f32
+  call @trunc_bf16(%tieUpF): (f32) -> ()
+
+  // Negative ties round to even as well.
+  // CHECK-NEXT: {{^}}-1{{$}}
+  %negTieI = arith.constant 0xbf808000 : i32
+  %negTieF = arith.bitcast %negTieI : i32 to f32
+  call @trunc_bf16(%negTieF): (f32) -> ()
+
+  // Anything above a tie rounds away from zero.
+  // CHECK-NEXT: -1.00781
+  %negAboveTieI = arith.constant 0xbf808001 : i32
+  %negAboveTieF = arith.bitcast %negAboveTieI : i32 to f32
+  call @trunc_bf16(%negAboveTieF): (f32) -> ()
+
+  // CHECK-NEXT: {{^}}1.625{{$}}
+  %roundUpI = arith.constant 0x3fcfffff : i32
+  %roundUpF = arith.bitcast %roundUpI : i32 to f32
+  call @trunc_bf16(%roundUpF): (f32) -> ()
+
+  // CHECK-NEXT: {{^}}-1.625{{$}}
+  %negRoundUpI = arith.constant 0xbfcfffff : i32
+  %negRoundUpF = arith.bitcast %negRoundUpI : i32 to f32
+  call @trunc_bf16(%negRoundUpF): (f32) -> ()
+
+  // A carry out of the mantissa increments the exponent.
+  // CHECK-NEXT: {{^}}2{{$}}
+  %carryI = arith.constant 0x3fffffff : i32
+  %carryF = arith.bitcast %carryI : i32 to f32
+  call @trunc_bf16(%carryF): (f32) -> ()
+
+  // CHECK-NEXT: {{^}}inf{{$}}
+  %infI = arith.constant 0x7f800000 : i32
+  %infF = arith.bitcast %infI : i32 to f32
+  call @trunc_bf16(%infF): (f32) -> ()
+
+  // CHECK-NEXT: {{^}}-inf{{$}}
+  %negInfI = arith.constant 0xff800000 : i32
+  %negInfF = arith.bitcast %negInfI : i32 to f32
+  call @trunc_bf16(%negInfF): (f32) -> ()
+
+  // The largest finite f32 values round up to infinity in bf16.
+  // CHECK-NEXT: {{^}}inf{{$}}
+  %bigI = arith.constant 0x7f7fffff : i32
+  %bigF = arith.bitcast %bigI : i32 to f32
+  call @trunc_bf16(%bigF): (f32) -> ()
+
+  // CHECK-NEXT: {{^}}-inf{{$}}
+  %negBigI = arith.constant 0xff7fffff : i32
+  %negBigF = arith.bitcast %negBigI : i32 to f32
+  call @trunc_bf16(%negBigF): (f32) -> ()
+
+  // A NaN whose payload lies only in the low 16 bits must remain a NaN.
+  // CHECK-NEXT: nan
+  %nanI = arith.constant 0x7f800001 : i32
+  %nanF = arith.bitcast %nanI : i32 to f32
+  call @trunc_bf16(%nanF): (f32) -> ()
+
+  return
+}