diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -38,6 +38,9 @@
 /// Add patterns to expand Arith ceil/floor division ops.
 void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
 
+/// Add patterns to expand Arith bf16 <-> f32 extf/truncf into integer ops.
+void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith ops.
 void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -25,15 +26,13 @@
 /// Create an integer or index constant.
 static Value createConst(Location loc, Type type, int value,
                          PatternRewriter &rewriter) {
-
-  auto elTy = getElementTypeOrSelf(type);
-  auto constantAttr = rewriter.getIntegerAttr(elTy, value);
-
-  if (auto vecTy = llvm::dyn_cast<VectorType>(type))
+  auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
+  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
     return rewriter.create<arith::ConstantOp>(
-        loc, vecTy, DenseElementsAttr::get(vecTy, constantAttr));
+        loc, DenseElementsAttr::get(shapedTy, attr));
+  }
 
-  return rewriter.create<arith::ConstantOp>(loc, constantAttr);
+  return rewriter.create<arith::ConstantOp>(loc, attr);
 }
 
 namespace {
@@ -187,6 +186,105 @@
   }
 };
 
+/// Expands `arith.extf` from bf16 to f32 into integer bit manipulation.
+/// bf16 is the upper half of an f32 bit pattern, so widening is exact: move
+/// the 16 payload bits into the high half of an i32 and bitcast to f32.
+struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!operandETy.isBF16() || !resultETy.isF32()) {
+      return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
+    }
+
+    Type i16Ty = b.getI16Type();
+    Type i32Ty = b.getI32Type();
+    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+      // createConst materializes splat constants, which need a static shape.
+      if (!shapedTy.hasStaticShape())
+        return rewriter.notifyMatchFailure(op, "dynamic shapes unsupported.");
+      i16Ty = shapedTy.clone(i16Ty);
+      i32Ty = shapedTy.clone(i32Ty);
+    }
+
+    Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
+    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+
+    Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
+    Value shl = b.create<arith::ShLIOp>(exti, c16);
+    Value result = b.create<arith::BitcastOp>(resultTy, shl);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/// Expands `arith.truncf` from f32 to bf16 into integer bit manipulation.
+/// The result is rounded to nearest, ties to even (the IEEE-754 default
+/// rounding mode used by `arith.truncf`), rather than simply dropping the low
+/// 16 bits. NaN inputs produce the canonical quiet NaN 0x7FC0 so that a NaN
+/// whose payload lives only in the discarded bits cannot become an infinity.
+struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!operandETy.isF32() || !resultETy.isBF16()) {
+      return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
+    }
+
+    Type i16Ty = b.getI16Type();
+    Type i32Ty = b.getI32Type();
+    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
+      // createConst materializes splat constants, which need a static shape.
+      if (!shapedTy.hasStaticShape())
+        return rewriter.notifyMatchFailure(op, "dynamic shapes unsupported.");
+      i16Ty = shapedTy.clone(i16Ty);
+      i32Ty = shapedTy.clone(i32Ty);
+    }
+
+    Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
+    Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
+    Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7FFF, rewriter);
+    Value cQuietNaN = createConst(op.getLoc(), i16Ty, 0x7FC0, rewriter);
+
+    Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
+    // Round to nearest, ties to even: add 0x7FFF plus the lowest bit that is
+    // kept (bit 16). A carry out of the mantissa correctly bumps the exponent,
+    // and finite values that round past the largest bf16 become infinity.
+    // Infinities have a zero mantissa, so they pass through unchanged.
+    Value shr = b.create<arith::ShRUIOp>(bitcast, c16);
+    Value lsb = b.create<arith::AndIOp>(shr, c1);
+    Value bias = b.create<arith::AddIOp>(lsb, c7FFF);
+    Value rounded = b.create<arith::AddIOp>(bitcast, bias);
+    Value roundedShr = b.create<arith::ShRUIOp>(rounded, c16);
+    Value trunc = b.create<arith::TruncIOp>(i16Ty, roundedShr);
+
+    // Adding the bias is meaningless for NaNs (it can even carry into the sign
+    // bit), so substitute a quiet NaN for them.
+    Value isNaN =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, operand, operand);
+    Value select = b.create<arith::SelectOp>(isNaN, cQuietNaN, trunc);
+    Value result = b.create<arith::BitcastOp>(resultTy, select);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct ArithExpandOpsPass
     : public arith::impl::ArithExpandOpsBase<ArithExpandOpsPass> {
   void runOnOperation() override {
@@ -204,6 +302,21 @@
       arith::MaxFOp,
       arith::MinFOp
     >();
+
+    target.addDynamicallyLegalOp<arith::ExtFOp>(
+      [](arith::ExtFOp op) {
+        Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+        Type outETy = getElementTypeOrSelf(op.getType());
+        return !(inETy.isBF16() && outETy.isF32());
+      });
+
+    target.addDynamicallyLegalOp<arith::TruncFOp>(
+      [](arith::TruncFOp op) {
+        Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+        Type outETy = getElementTypeOrSelf(op.getType());
+        return !(inETy.isF32() && outETy.isBF16());
+      });
+
     // clang-format on
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -220,12 +333,19 @@
           patterns.getContext());
 }
 
+void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
+  patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
+      patterns.getContext());
+}
+
 void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
   populateCeilFloorDivExpandOpsPatterns(patterns);
   // clang-format off
   patterns.add<
     MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::UGT>,
-    MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::ULT>
+    MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::ULT>,
+    BFloat16ExtFOpConverter,
+    BFloat16TruncFOpConverter
   >(patterns.getContext());
   // clang-format on
 }
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -215,3 +215,85 @@
 // CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32
 // CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
 // CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+func.func @extf_bf16(%arg0 : bf16) -> f32 {
+  %0 = arith.extf %arg0 : bf16 to f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @extf_bf16
+// CHECK-SAME: %[[ARG0:.+]]: bf16
+// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[ARG0]] : bf16 to i16
+// CHECK-DAG: %[[EXT:.+]] = arith.extui %[[BITCAST]] : i16 to i32
+// CHECK-DAG: %[[C16:.+]] = arith.constant 16
+// CHECK-DAG: %[[SHLI:.+]] = arith.shli %[[EXT]], %[[C16]]
+// CHECK-DAG: %[[RESULT:.+]] = arith.bitcast %[[SHLI]] : i32 to f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @extf_vector_bf16(%arg0 : vector<4xbf16>) -> vector<4xf32> {
+  %0 = arith.extf %arg0 : vector<4xbf16> to vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extf_vector_bf16
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xbf16>
+// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[ARG0]] : vector<4xbf16> to vector<4xi16>
+// CHECK-DAG: %[[EXT:.+]] = arith.extui %[[BITCAST]] : vector<4xi16> to vector<4xi32>
+// CHECK-DAG: %[[C16:.+]] = arith.constant dense<16>
+// CHECK-DAG: %[[SHLI:.+]] = arith.shli %[[EXT]], %[[C16]]
+// CHECK-DAG: %[[RESULT:.+]] = arith.bitcast %[[SHLI]] : vector<4xi32> to vector<4xf32>
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @truncf_f32(%arg0 : f32) -> bf16 {
+  %0 = arith.truncf %arg0 : f32 to bf16
+  return %0 : bf16
+}
+
+// CHECK-LABEL: @truncf_f32
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32
+// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : i32
+// CHECK-DAG: %[[C7FFF:.+]] = arith.constant 32767 : i32
+// CHECK-DAG: %[[QNAN:.+]] = arith.constant 32704 : i16
+// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[ARG0]] : f32 to i32
+// CHECK-DAG: %[[SHR:.+]] = arith.shrui %[[BITCAST]], %[[C16]]
+// CHECK-DAG: %[[LSB:.+]] = arith.andi %[[SHR]], %[[C1]]
+// CHECK-DAG: %[[BIAS:.+]] = arith.addi %[[LSB]], %[[C7FFF]]
+// CHECK-DAG: %[[ROUNDED:.+]] = arith.addi %[[BITCAST]], %[[BIAS]]
+// CHECK-DAG: %[[ROUNDED_SHR:.+]] = arith.shrui %[[ROUNDED]], %[[C16]]
+// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[ROUNDED_SHR]] : i32 to i16
+// CHECK-DAG: %[[ISNAN:.+]] = arith.cmpf uno, %[[ARG0]], %[[ARG0]] : f32
+// CHECK-DAG: %[[SELECT:.+]] = arith.select %[[ISNAN]], %[[QNAN]], %[[TRUNC]] : i16
+// CHECK-DAG: %[[RESULT:.+]] = arith.bitcast %[[SELECT]] : i16 to bf16
+// CHECK: return %[[RESULT]] : bf16
+
+// -----
+
+func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
+  %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xbf16>
+  return %0 : vector<4xbf16>
+}
+
+// CHECK-LABEL: @truncf_vector_f32
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xf32>
+// CHECK-DAG: %[[C1:.+]] = arith.constant dense<1> : vector<4xi32>
+// CHECK-DAG: %[[C16:.+]] = arith.constant dense<16> : vector<4xi32>
+// CHECK-DAG: %[[C7FFF:.+]] = arith.constant dense<32767> : vector<4xi32>
+// CHECK-DAG: %[[QNAN:.+]] = arith.constant dense<32704> : vector<4xi16>
+// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[ARG0]] : vector<4xf32> to vector<4xi32>
+// CHECK-DAG: %[[SHR:.+]] = arith.shrui %[[BITCAST]], %[[C16]]
+// CHECK-DAG: %[[LSB:.+]] = arith.andi %[[SHR]], %[[C1]]
+// CHECK-DAG: %[[BIAS:.+]] = arith.addi %[[LSB]], %[[C7FFF]]
+// CHECK-DAG: %[[ROUNDED:.+]] = arith.addi %[[BITCAST]], %[[BIAS]]
+// CHECK-DAG: %[[ROUNDED_SHR:.+]] = arith.shrui %[[ROUNDED]], %[[C16]]
+// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[ROUNDED_SHR]] : vector<4xi32> to vector<4xi16>
+// CHECK-DAG: %[[ISNAN:.+]] = arith.cmpf uno, %[[ARG0]], %[[ARG0]] : vector<4xf32>
+// CHECK-DAG: %[[SELECT:.+]] = arith.select %[[ISNAN]], %[[QNAN]], %[[TRUNC]]
+// CHECK-DAG: %[[RESULT:.+]] = arith.bitcast %[[SELECT]] : vector<4xi16> to vector<4xbf16>
+// CHECK: return %[[RESULT]] : vector<4xbf16>