diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -15,14 +15,17 @@
     name: A
     usage: input
     shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
+    element_type_var: T
  - !LinalgTensorDef
    name: B
    usage: input
    shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
+    element_type_var: T
  - !LinalgTensorDef
    name: C
    usage: output
    shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
+    element_type_var: U
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
@@ -46,7 +49,15 @@
            fn_name: mul
            operands:
            - !ScalarExpression
-              scalar_arg: A
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
            - !ScalarExpression
-              scalar_arg: B
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -155,6 +155,45 @@ public:
   RegionBuilderHelper(Block &block) : block(block) {}

+  // Generates operations to cast the given operand to the specified type.
+  // If the cast cannot be performed, a warning is issued and the operand
+  // is returned as-is (which will presumably yield a verification issue
+  // downstream).
+  Value cast(Type toType, Value operand) {
+    OpBuilder builder = getBuilder(operand);
+    auto loc = operand.getLoc();
+
+    if (operand.getType() == toType)
+      return operand;
+    if (auto toIntType = toType.dyn_cast<IntegerType>()) {
+      // If the operand is floating point, cast directly to the int type.
+      if (operand.getType().isa<FloatType>())
+        return builder.create<FPToSIOp>(loc, toType, operand);
+      if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
+        // Either sign extend or truncate.
+        if (toIntType.getWidth() > fromIntType.getWidth())
+          return builder.create<SignExtendIOp>(loc, toType, operand);
+        else if (toIntType.getWidth() < fromIntType.getWidth())
+          return builder.create<TruncateIOp>(loc, toType, operand);
+      }
+    } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
+      // If the operand is an integer, cast directly to the float type.
+      // Note that it is unclear how to cast from BF16<->FP16.
+      if (operand.getType().isa<IntegerType>())
+        return builder.create<SIToFPOp>(loc, toFloatType, operand);
+      if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
+        if (toFloatType.getWidth() > fromFloatType.getWidth())
+          return builder.create<FPExtOp>(loc, toFloatType, operand);
+        else if (toFloatType.getWidth() < fromFloatType.getWidth())
+          return builder.create<FPTruncOp>(loc, toFloatType, operand);
+      }
+    }
+
+    emitWarning(operand.getLoc()) << "could not cast operand of type "
+                                  << operand.getType() << " to " << toType;
+    return operand;
+  }
+
   Value applyfn__add(Value lhs, Value rhs) {
     OpBuilder builder = getBuilder(lhs);
     if (isFloatingPoint(lhs))
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -25,3 +25,99 @@
 // CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
 // CHECK-NEXT: linalg.yield %[[ADD]] : i32
 // CHECK-NEXT: -> tensor<16x32xi32>
+
+// -----
+// Verifies floating point to integer cast.
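+// Note: the cast helper assumes signed semantics, so f32 -> i16 lowers to fptosi.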
+func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
+                                 outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
+  return %0: tensor<16x32xi16>
+}
+
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: i16)
+// CHECK-NEXT: %[[A_CAST:.+]] = fptosi %[[A_ARG]] : f32 to i16
+// CHECK-NEXT: %[[B_CAST:.+]] = fptosi %[[B_ARG]] : f32 to i16
+// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i16
+// CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
+// CHECK-NEXT: linalg.yield %[[ADD]] : i16
+// CHECK-NEXT: -> tensor<16x32xi16>
+
+// -----
+// Verifies sign extension cast.
+func @generalize_matmul_tensor_i8_i8_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
+                                 outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
+  return %0: tensor<16x32xi32>
+}
+
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32)
+// CHECK-NEXT: %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32
+// CHECK-NEXT: %[[B_CAST:.+]] = sexti %[[B_ARG]] : i8 to i32
+// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32
+// CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
+// CHECK-NEXT: linalg.yield %[[ADD]] : i32
+// CHECK-NEXT: -> tensor<16x32xi32>
+
+// -----
+// Somewhat non-sensical but checks integer truncation cast.
+func @generalize_matmul_tensor_i32_i32_i16(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
+                                 outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
+  return %0: tensor<16x32xi16>
+}
+
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i16)
+// CHECK-NEXT: %[[A_CAST:.+]] = trunci %[[A_ARG]] : i32 to i16
+// CHECK-NEXT: %[[B_CAST:.+]] = trunci %[[B_ARG]] : i32 to i16
+// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i16
+// CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
+// CHECK-NEXT: linalg.yield %[[ADD]] : i16
+// CHECK-NEXT: -> tensor<16x32xi16>
+
+// -----
+// Verifies integer to floating point cast.
+func @generalize_matmul_tensor_i8_i8_f32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
+                                 outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: f32)
+// CHECK-NEXT: %[[A_CAST:.+]] = sitofp %[[A_ARG]] : i8 to f32
+// CHECK-NEXT: %[[B_CAST:.+]] = sitofp %[[B_ARG]] : i8 to f32
+// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
+
+// -----
+// Verifies floating point extension cast.
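+// f16 is narrower than f32, so the cast helper emits fpext here.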
+func @generalize_matmul_tensor_f16_f16_f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf16>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf16>)
+                                 outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
+// CHECK-NEXT: %[[A_CAST:.+]] = fpext %[[A_ARG]] : f16 to f32
+// CHECK-NEXT: %[[B_CAST:.+]] = fpext %[[B_ARG]] : f16 to f32
+// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
+
+// -----
+// Verifies floating point truncation.
+func @generalize_matmul_tensor_f64_f64_f32(%A : tensor<16x8xf64>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf64>, tensor<8x32xf64>)
+                                 outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f64, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
+// CHECK-NEXT: %[[A_CAST:.+]] = fptrunc %[[A_ARG]] : f64 to f32
+// CHECK-NEXT: %[[B_CAST:.+]] = fptrunc %[[B_ARG]] : f64 to f32
+// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
+// CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -72,6 +72,7 @@
   std::string name;
   LinalgTensorUsageDef usage;
   SerializedAffineMap shape;
+  std::string elementTypeVar;
 };

 enum class LinalgIteratorTypeDef {
@@ -92,9 +93,17 @@
   std::vector<ScalarExpression> operands;
 };

+struct ScalarSymbolicCast {
+  std::string typeVar;
+  // NOTE: This must be of arity 1, but to break the self-referential cycle,
+  // we use a heap-allocated vector.
+  std::vector<ScalarExpression> operands;
+};
+
 struct ScalarExpression {
-  Optional<std::string> scalarArg;
-  Optional<ScalarApply> scalarApply;
+  Optional<std::string> arg;
+  Optional<ScalarApply> apply;
+  Optional<ScalarSymbolicCast> symbolicCast;
 };

 struct ScalarAssign {
@@ -163,12 +172,15 @@
 /// - `shape`: An AffineMap from all op symbols to the specific shape
 ///   of this argument. Each shape must be normalized over the same list of
 ///   symbols and have no dimension inputs.
+/// - `element_type_var`: The symbolic type variable that binds to the scalar
+///   element type of this TensorDef.
 template <>
 struct MappingTraits<LinalgTensorDef> {
   static void mapping(IO &io, LinalgTensorDef &info) {
     io.mapRequired("name", info.name);
     io.mapRequired("usage", info.usage);
     io.mapRequired("shape", info.shape);
+    io.mapRequired("element_type_var", info.elementTypeVar);
   }
 };

@@ -230,11 +242,13 @@
 /// - `scalar_arg`: Name of an argument to the op.
 /// - `scalar_apply`: Result of evaluating a named function (see
 ///   `ScalarApply`).
+/// - `symbolic_cast`: Cast to a symbolic TypeVar bound elsewhere.
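+///
+///   For example, the polymorphic_matmul definition above casts each
+///   multiplicand to the output element type with:
+///
+///     symbolic_cast:
+///       type_var: U
+///       operands:
+///       - !ScalarExpression
+///         scalar_arg: A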
 template <>
 struct MappingTraits<ScalarExpression> {
   static void mapping(IO &io, ScalarExpression &info) {
-    io.mapOptional("scalar_arg", info.scalarArg);
-    io.mapOptional("scalar_apply", info.scalarApply);
+    io.mapOptional("scalar_arg", info.arg);
+    io.mapOptional("scalar_apply", info.apply);
+    io.mapOptional("symbolic_cast", info.symbolicCast);
   }
 };

@@ -251,6 +265,14 @@
   }
 };

+template <>
+struct MappingTraits<ScalarSymbolicCast> {
+  static void mapping(IO &io, ScalarSymbolicCast &info) {
+    io.mapRequired("type_var", info.typeVar);
+    io.mapRequired("operands", info.operands);
+  }
+};
+
 /// Helper mapping which accesses an AffineMapAttr as a serialized string of
 /// the same.
 template <>
@@ -348,6 +370,15 @@
   return None;
 }

+static Optional<int>
+findTypeVarArgIndex(StringRef typeVar, SmallVectorImpl<LinalgTensorDef> &args) {
+  for (auto it : llvm::enumerate(args)) {
+    if (it.value().elementTypeVar == typeVar)
+      return it.index();
+  }
+  return None;
+}
+
 static ScalarAssign *
 findAssignment(StringRef name, SmallVectorImpl<ScalarAssign> &assignments) {
   for (auto &assign : assignments) {
@@ -733,9 +764,9 @@
   std::function<Optional<std::string>(ScalarExpression &)> generateExpression =
       [&](ScalarExpression &expression) -> Optional<std::string> {
-    if (expression.scalarArg) {
-      Optional<int> argIndex =
-          findTensorDefArgIndex(*expression.scalarArg, args);
+    if (expression.arg) {
+      // Argument reference.
+      Optional<int> argIndex = findTensorDefArgIndex(*expression.arg, args);
       if (!argIndex) {
         emitError(genContext.getLoc())
             << "scalar argument not defined on the op: " << arg.name;
@@ -743,10 +774,11 @@
       }
       return std::string(
           llvm::formatv("block.getArgument({0})", *argIndex));
-    } else if (expression.scalarApply) {
+    } else if (expression.apply) {
+      // Apply function.
       // Recursively generate operands.
       SmallVector<std::string> operandCppValues;
-      for (ScalarExpression &operand : expression.scalarApply->operands) {
+      for (ScalarExpression &operand : expression.apply->operands) {
         auto operandCppValue = generateExpression(operand);
         if (!operandCppValue)
           return None;
@@ -755,9 +787,41 @@
       std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
       stmts.push_back(
           llvm::formatv("Value {0} = helper.applyfn__{1}({2});", cppIdent,
-                        expression.scalarApply->fnName,
+                        expression.apply->fnName,
                         interleaveToString(operandCppValues, ", ")));
       return cppIdent;
+    } else if (expression.symbolicCast) {
+      // Symbolic cast.
+      // Operands must be of arity 1.
+      if (expression.symbolicCast->operands.size() != 1) {
+        emitError(genContext.getLoc())
+            << "symbolic_cast operand arity must be 1";
+        return None;
+      }
+      Optional<std::string> operandCppValue =
+          generateExpression(expression.symbolicCast->operands[0]);
+      if (!operandCppValue)
+        return None;
+
+      // Try to map the TypeVar to an arg index (which maps to a block arg
+      // index), since we can then get the type directly.
+      // TODO: Handle free type variables which do not map to an argument.
+      Optional<int> typeArgIndex =
+          findTypeVarArgIndex(expression.symbolicCast->typeVar, args);
+      if (!typeArgIndex) {
+        emitError(genContext.getLoc())
+            << "type variable " << expression.symbolicCast->typeVar
+            << ", used in a symbolic cast, must map to an argument but it "
+            << "does not";
+        return None;
+      }
+      std::string typeCppValue =
+          llvm::formatv("block.getArgument({0}).getType()", *typeArgIndex);
+      std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
+      stmts.push_back(llvm::formatv("Value {0} = helper.cast({1}, {2});",
+                                    cppIdent, typeCppValue,
+                                    *operandCppValue));
+      return cppIdent;
     } else {
       emitError(genContext.getLoc()) << "unknown ScalarExpression type";
       return None;
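
For reference, the snippet below sketches roughly what the generator emits for polymorphic_matmul's region builder once symbolic_cast is wired through. The two helper.cast calls follow the formatv templates added above, with U resolving to the element type of the output argument C (block argument 2). The enclosing regionBuilder wrapper and the yieldOutputs call are assumptions for readability, not verbatim generator output.

    // Hypothetical rendering of the generated region builder for
    // polymorphic_matmul. The value{N} names come from the formatv templates
    // in this patch; the wrapper and yieldOutputs call are assumed.
    static void regionBuilder(Block &block) {
      RegionBuilderHelper helper(block);
      SmallVector<Value> yields;
      // symbolic_cast with type_var U: U binds to C, i.e. block argument 2.
      Value value1 =
          helper.cast(block.getArgument(2).getType(), block.getArgument(0));
      Value value2 =
          helper.cast(block.getArgument(2).getType(), block.getArgument(1));
      // scalar_apply: mul, then add into the accumulator.
      Value value3 = helper.applyfn__mul(value1, value2);
      Value value4 = helper.applyfn__add(block.getArgument(2), value3);
      yields.push_back(value4);
      helper.yieldOutputs(yields); // assumed helper; yields the result
    }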