mirror of
https://github.com/Gericom/teak-llvm.git
synced 2025-06-19 11:35:51 -04:00
[mlir] Enable specifying verify on OpInterface
Summary: Add method in ODS to specify verification for operations implementing a OpInterface. Use this with infer type op interface to verify that the inferred type matches the return type and remove special case in TestPatterns. This could also have been achieved by using OpInterfaceMethod but verify seems pretty common and it is not an arbitrary method that just happened to be named verifyTrait, so having it be defined in special way seems appropriate/better documenting. Differential Revision: https://reviews.llvm.org/D73122
This commit is contained in:
parent
963f268186
commit
178562fb35
@ -400,6 +400,10 @@ def OpWithInferTypeInterfaceOp : Op<...
|
||||
[DeclareOpInterfaceMethods<MyInterface>]> { ... }
|
||||
```
|
||||
|
||||
A verification method can also be specified on the `OpInterface` by setting
|
||||
`verify`. Setting `verify` results in the generated trait having a `verifyTrait`
|
||||
method that is applied to all operations implementing the trait.
|
||||
|
||||
### Builder methods
|
||||
|
||||
For each operation, there are a few builders automatically generated based on
|
||||
|
@ -72,8 +72,6 @@ private:
|
||||
Attribute attr;
|
||||
};
|
||||
|
||||
#include "mlir/Analysis/InferTypeOpInterface.h.inc"
|
||||
|
||||
namespace detail {
|
||||
// Helper function to infer return tensor returns types given element and shape
|
||||
// inference function.
|
||||
@ -89,8 +87,14 @@ LogicalResult inferReturnTensorTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
ArrayRef<NamedAttribute> attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferedReturnTypes);
|
||||
|
||||
/// Verifies that the inferred result types match the actual result types for
|
||||
/// the op. Precondition: op implements InferTypeOpInterface.
|
||||
LogicalResult verifyInferredResultTypes(Operation *op);
|
||||
} // namespace detail
|
||||
|
||||
#include "mlir/Analysis/InferTypeOpInterface.h.inc"
|
||||
|
||||
namespace OpTrait {
|
||||
|
||||
/// Tensor type inference trait that constructs a tensor from the infered
|
||||
|
@ -60,6 +60,10 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
|
||||
}]
|
||||
>,
|
||||
];
|
||||
|
||||
let verify = [{
|
||||
return detail::verifyInferredResultTypes($_op);
|
||||
}];
|
||||
}
|
||||
|
||||
def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
|
||||
|
@ -1411,8 +1411,12 @@ def ins;
|
||||
// OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in
|
||||
// C++. The purpose to wrap around C++ symbol string with this class is to make
|
||||
// interfaces specified for ops in TableGen less alien and more integrated.
|
||||
class OpInterfaceTrait<string name> : NativeOpTrait<""> {
|
||||
class OpInterfaceTrait<string name, code verifyBody = [{}]> : NativeOpTrait<""> {
|
||||
let trait = name # "::Trait";
|
||||
|
||||
// Specify the body of the verification function. `$_op` will be replaced with
|
||||
// the operation being verified.
|
||||
code verify = verifyBody;
|
||||
}
|
||||
|
||||
// This class represents a single, optionally static, interface method.
|
||||
|
@ -86,6 +86,9 @@ public:
|
||||
// Return the description of this method if it has one.
|
||||
llvm::Optional<StringRef> getDescription() const;
|
||||
|
||||
// Return the verify method body if it has one.
|
||||
llvm::Optional<StringRef> getVerify() const;
|
||||
|
||||
private:
|
||||
// The TableGen definition of this interface.
|
||||
const llvm::Record *def;
|
||||
|
@ -45,3 +45,17 @@ LogicalResult mlir::detail::inferReturnTensorTypes(
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
|
||||
SmallVector<Type, 4> inferedReturnTypes;
|
||||
auto retTypeFn = cast<InferTypeOpInterface>(op);
|
||||
if (failed(retTypeFn.inferReturnTypes(op->getContext(), op->getLoc(),
|
||||
op->getOperands(), op->getAttrs(),
|
||||
op->getRegions(), inferedReturnTypes)))
|
||||
return failure();
|
||||
SmallVector<Type, 4> resultTypes(op->getResultTypes());
|
||||
if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes))
|
||||
return op->emitOpError(
|
||||
"inferred type incompatible with return type of operation");
|
||||
return success();
|
||||
}
|
||||
|
@ -85,3 +85,9 @@ llvm::Optional<StringRef> OpInterface::getDescription() const {
|
||||
auto value = def->getValueAsString("description");
|
||||
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||
}
|
||||
|
||||
// Return the body for this method if it has one.
|
||||
llvm::Optional<StringRef> OpInterface::getVerify() const {
|
||||
auto value = def->getValueAsString("verify");
|
||||
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||
}
|
||||
|
@ -103,26 +103,6 @@ struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
|
||||
};
|
||||
return;
|
||||
}
|
||||
|
||||
// Verification check.
|
||||
// TODO: Move to ops that implement type infer interface.
|
||||
getFunction().walk([this](Operation *op) -> void {
|
||||
auto retTypeFn = dyn_cast<InferTypeOpInterface>(op);
|
||||
if (!retTypeFn)
|
||||
return;
|
||||
auto *context = &getContext();
|
||||
SmallVector<Type, 2> inferedReturnTypes;
|
||||
if (failed(retTypeFn.inferReturnTypes(
|
||||
context, op->getLoc(), op->getOperands(), op->getAttrs(),
|
||||
op->getRegions(), inferedReturnTypes)))
|
||||
return;
|
||||
SmallVector<Type, 1> resultTypes(op->getResultTypes());
|
||||
if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes)) {
|
||||
op->emitOpError(
|
||||
"inferred type incompatible with return type of operation");
|
||||
return;
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
@ -23,7 +23,6 @@ func @testCreateFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testReturnTypeOpInterface
|
||||
func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
|
||||
// expected-error@+1 {{incompatible with return type}}
|
||||
%bad = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
|
||||
@ -32,7 +31,6 @@ func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testReturnTypeOpInterface
|
||||
func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) {
|
||||
// expected-error@+1 {{operand type mismatch}}
|
||||
%bad = "test.op_with_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<*xf32>
|
||||
|
@ -12,6 +12,7 @@
|
||||
|
||||
#include "DocGenUtilities.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "mlir/TableGen/Format.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "mlir/TableGen/OpInterfaces.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
@ -152,6 +153,12 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
|
||||
|
||||
// Insert the default implementation for any methods.
|
||||
for (auto &method : interface.getMethods()) {
|
||||
// Flag interface methods named verifyTrait.
|
||||
if (method.getName() == "verifyTrait")
|
||||
PrintFatalError(
|
||||
formatv("'verifyTrait' method cannot be specified as interface "
|
||||
"method for '{0}'; set 'verify' on OpInterfaceTrait instead",
|
||||
interfaceName));
|
||||
auto defaultImpl = method.getDefaultImplementation();
|
||||
if (!defaultImpl)
|
||||
continue;
|
||||
@ -162,6 +169,13 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
|
||||
os << " {\n" << defaultImpl.getValue() << " }\n";
|
||||
}
|
||||
|
||||
tblgen::FmtContext traitCtx;
|
||||
traitCtx.withOp("op");
|
||||
if (auto verify = interface.getVerify()) {
|
||||
os << " static LogicalResult verifyTrait(Operation* op) {\n"
|
||||
<< tblgen::tgfmt(*verify, &traitCtx) << "\n }\n";
|
||||
}
|
||||
|
||||
os << " };\n";
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user